"""
Borrowed from https://github.com/WujiangXu/AgenticMemory/blob/main/utils.py

@article{xu2025mem,
    title={A-mem: Agentic memory for llm agents},
    author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
           and Zhang, Yongfeng},
    journal={arXiv preprint arXiv:2502.12110},
    year={2025}
}
"""

# import statistics
# from collections import defaultdict
from typing import Dict, List, Union

import nltk
import os
# from bert_score import score as bert_score
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

# Ensure NLTK data is downloaded
def ensure_nltk_punkt():
    """Ensure the NLTK Punkt tokenizer data is available, downloading it if missing.

    Both the ``punkt`` and ``punkt_tab`` resources are checked: older NLTK
    releases load ``punkt`` in ``nltk.word_tokenize``, while NLTK 3.8.2 and
    later load ``punkt_tab`` instead.

    This is best-effort. Failures are printed and never raised, so importing
    this module keeps working offline.
    """
    for resource in ('punkt', 'punkt_tab'):
        try:
            nltk.data.find(f'tokenizers/{resource}')
        except LookupError:
            pass
        else:
            continue  # Resource is already installed.

        print(f"Downloading NLTK {resource} tokenizer...")
        try:
            # nltk.download reports most failures by returning False rather
            # than raising, so the return value has to be checked as well.
            downloaded = nltk.download(resource, quiet=True)
        except Exception as e:
            print(f"Failed to download {resource}: {e}")
            downloaded = False
        if downloaded:
            continue

        # Try with explicit path
        try:
            if not nltk.download(resource, download_dir=os.path.expanduser('~/nltk_data')):
                print(f"Fallback download failed: could not fetch {resource}")
        except Exception as e2:
            print(f"Fallback download failed: {e2}")

# Initialize punkt tokenizer at import time. When the tokenizer data is not
# found locally this prints a progress message and attempts a download.
ensure_nltk_punkt()
# from nltk.translate.meteor_score import meteor_score
# from rouge_score import rouge_scorer
# from sentence_transformers import SentenceTransformer

# from load_dataset import load_locomo_dataset, QA, Turn, Session, Conversation
# from sentence_transformers.util import pytorch_cos_sim

# # Download required NLTK data
# try:
#     nltk.download("punkt", quiet=True)
#     nltk.download("wordnet", quiet=True)
# except Exception as e:
#     print(f"Error downloading NLTK data: {e}")

# # Initialize SentenceTransformer model (this will be reused)
# try:
#     sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
# except Exception as e:
#     print(f"Warning: Could not load SentenceTransformer model: {e}")
#     sentence_model = None


def simple_tokenize(text):
    """Split *text* into lowercase whitespace-separated tokens.

    The input is first converted with ``str``. The characters ``.``, ``,``,
    ``!`` and ``?`` are treated as separators; all other punctuation stays
    attached to its token.
    """
    # Map each separator character to a space in a single pass.
    separators = str.maketrans(".,!?", "    ")
    lowered = str(text).lower()
    return lowered.translate(separators).split()


# def calculate_rouge_scores(prediction: str, reference: str) -> Dict[str, float]:
#     """Calculate ROUGE scores for prediction against reference."""
#     scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
#     scores = scorer.score(reference, prediction)
#     return {
#         "rouge1_f": scores["rouge1"].fmeasure,
#         "rouge2_f": scores["rouge2"].fmeasure,
#         "rougeL_f": scores["rougeL"].fmeasure,
#     }


def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
    """Compute smoothed sentence-level BLEU for four n-gram weightings.

    Both strings are lowercased and tokenized with ``nltk.word_tokenize``. If
    that raises for any reason, ``simple_tokenize`` is used for both instead.

    Returns:
        A dict with keys ``bleu1`` .. ``bleu4``, one per weight tuple. A score
        whose computation raises is reported as ``0.0``.
    """
    try:
        hypothesis = nltk.word_tokenize(prediction.lower())
        references = [nltk.word_tokenize(reference.lower())]
    except Exception as e:
        print(f"NLTK tokenization failed: {e}, using simple tokenization")
        # Fallback to simple tokenization
        hypothesis = simple_tokenize(prediction)
        references = [simple_tokenize(reference)]

    # One weight tuple per BLEU variant, in order bleu1 .. bleu4.
    ngram_weights = (
        (1, 0, 0, 0),
        (0.5, 0.5, 0, 0),
        (0.33, 0.33, 0.33, 0),
        (0.25, 0.25, 0.25, 0.25),
    )
    smoothing = SmoothingFunction().method1

    results = {}
    for order, weights in enumerate(ngram_weights, start=1):
        key = f"bleu{order}"
        try:
            results[key] = sentence_bleu(
                references,
                hypothesis,
                weights=weights,
                smoothing_function=smoothing,
            )
        except Exception as e:
            print(f"Error calculating BLEU score: {e}")
            results[key] = 0.0

    return results


# def calculate_bert_scores(prediction: str, reference: str) -> Dict[str, float]:
#     """Calculate BERTScore for semantic similarity."""
#     try:
#         P, R, F1 = bert_score([prediction], [reference], lang="en", verbose=False)
#         return {"bert_precision": P.item(), "bert_recall": R.item(), "bert_f1": F1.item()}
#     except Exception as e:
#         print(f"Error calculating BERTScore: {e}")
#         return {"bert_precision": 0.0, "bert_recall": 0.0, "bert_f1": 0.0}


# def calculate_meteor_score(prediction: str, reference: str) -> float:
#     """Calculate METEOR score for the prediction."""
#     try:
#         return meteor_score([reference.split()], prediction.split())
#     except Exception as e:
#         print(f"Error calculating METEOR score: {e}")
#         return 0.0


# def calculate_sentence_similarity(prediction: str, reference: str) -> float:
#     """Calculate sentence embedding similarity using SentenceBERT."""
#     if sentence_model is None:
#         return 0.0
#     try:
#         # Encode sentences
#         embedding1 = sentence_model.encode([prediction], convert_to_tensor=True)
#         embedding2 = sentence_model.encode([reference], convert_to_tensor=True)

#         # Calculate cosine similarity
#         similarity = pytorch_cos_sim(embedding1, embedding2).item()
#         return float(similarity)
#     except Exception as e:
#         print(f"Error calculating sentence similarity: {e}")
#         return 0.0


def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
    """Calculate evaluation metrics for a prediction against a reference.

    Args:
        prediction: Model output. Non-string values are converted with ``str``.
        reference: Ground-truth answer. Non-string values are converted with
            ``str``.

    Returns:
        A dict containing:

        * ``exact_match`` -- 1 if the stripped texts are equal ignoring case,
          otherwise 0.
        * ``f1`` -- F1 over the *sets* of tokens from ``simple_tokenize``
          (repeated tokens count once).
        * ``bleu1`` .. ``bleu4`` -- from ``calculate_bleu_scores``.

        If either input is ``None`` or blank after stripping, every metric is
        zero. That result also carries zeroed ``rouge1_f``, ``rouge2_f``,
        ``rougeL_f``, ``bert_f1``, ``meteor`` and ``sbert_similarity`` keys,
        which the normal path does not compute.
    """
    # Only None counts as missing before conversion: falsy non-string answers
    # such as the integer 0 are valid and must be scored as the text "0".
    prediction = "" if prediction is None else str(prediction).strip()
    reference = "" if reference is None else str(reference).strip()

    # Handle empty values (checked after stripping so whitespace-only input
    # is treated as empty as well)
    if not prediction or not reference:
        return {
            "exact_match": 0,
            "f1": 0.0,
            "rouge1_f": 0.0,
            "rouge2_f": 0.0,
            "rougeL_f": 0.0,
            "bleu1": 0.0,
            "bleu2": 0.0,
            "bleu3": 0.0,
            "bleu4": 0.0,
            "bert_f1": 0.0,
            "meteor": 0.0,
            "sbert_similarity": 0.0,
        }

    # Calculate exact match
    exact_match = int(prediction.lower() == reference.lower())

    # Calculate token-based F1 score
    pred_tokens = set(simple_tokenize(prediction))
    ref_tokens = set(simple_tokenize(reference))
    common_tokens = pred_tokens & ref_tokens

    # Inputs made only of separator characters (e.g. "...") yield no tokens.
    if not pred_tokens or not ref_tokens:
        f1 = 0.0
    else:
        precision = len(common_tokens) / len(pred_tokens)
        recall = len(common_tokens) / len(ref_tokens)
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    # Calculate all scores
    bleu_scores = calculate_bleu_scores(prediction, reference)

    # Combine all metrics
    return {
        "exact_match": exact_match,
        "f1": f1,
        **bleu_scores,
    }


# def aggregate_metrics(
#     all_metrics: List[Dict[str, float]], all_categories: List[int]
# ) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
#     """Calculate aggregate statistics for all metrics, split by category."""
#     if not all_metrics:
#         return {}

#     # Initialize aggregates for overall and per-category metrics
#     aggregates = defaultdict(list)
#     category_aggregates = defaultdict(lambda: defaultdict(list))

#     # Collect all values for each metric, both overall and per category
#     for metrics, category in zip(all_metrics, all_categories):
#         for metric_name, value in metrics.items():
#             aggregates[metric_name].append(value)
#             category_aggregates[category][metric_name].append(value)

#     # Calculate statistics for overall metrics
#     results = {"overall": {}}

#     for metric_name, values in aggregates.items():
#         results["overall"][metric_name] = {
#             "mean": statistics.mean(values),
#             "std": statistics.stdev(values) if len(values) > 1 else 0.0,
#             "median": statistics.median(values),
#             "min": min(values),
#             "max": max(values),
#             "count": len(values),
#         }

#     # Calculate statistics for each category
#     for category in sorted(category_aggregates.keys()):
#         results[f"category_{category}"] = {}
#         for metric_name, values in category_aggregates[category].items():
#             if values:  # Only calculate if we have values for this category
#                 results[f"category_{category}"][metric_name] = {
#                     "mean": statistics.mean(values),
#                     "std": statistics.stdev(values) if len(values) > 1 else 0.0,
#                     "median": statistics.median(values),
#                     "min": min(values),
#                     "max": max(values),
#                     "count": len(values),
#                 }

#     return results
